#!/usr/bin/env python  
# -*- coding: utf-8 -*-
""" 
@author:hello_life 
@license: Apache Licence 
@file: main.py 
@time: 2022/05/03
@software: PyCharm 
description: Entry point that builds a BiLSTM-CRF model and CLUE train/test dataloaders, then runs training.
"""
import random

import torch
from torch.utils.data import DataLoader
import numpy as np

from utils.clue_dataset import CLUE_Dataset
from utils.train_utils import train_loop
from parameters.lstm_parameters import Config
from model.bilstm_crf import Bilstm_CRF

def set_seed(seed=2022):
    """Seed the Python, NumPy and PyTorch RNGs so runs are reproducible.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; PyTorch silently
    # ignores this call when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)


def main():
    """Build the BiLSTM-CRF model and CLUE dataloaders, then run training."""
    # Fix random seeds
    set_seed(2022)

    # Load parameters (device, batch size and data paths are read below)
    config = Config()

    # Build the model and move it to the configured device
    model = Bilstm_CRF(config).to(config.device)

    # Load datasets
    train_dataset = CLUE_Dataset(config, config.train_data_save_path)
    train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size,
                                  shuffle=True)
    test_dataset = CLUE_Dataset(config, config.test_data_save_path)
    # No shuffling for the held-out set: a fixed order keeps evaluation
    # deterministic and avoids consuming the global torch RNG, which would
    # otherwise perturb the training shuffle order.
    # NOTE(review): assumes train_loop only uses this loader for evaluation
    # and its metrics are order-independent -- confirm.
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size,
                                 shuffle=False)

    # Train
    train_loop(model, train_dataloader, test_dataloader, config)


if __name__ == '__main__':
    main()